# -*- coding: utf-8 -*-
"""
Created on Sun Jan 17 13:19:01 2021

@author: 98330
"""
import gurobipy as gr
import numpy as np
import root_const
from root_const import ROOT_PATH
from base_utils.my_utils import load_json, save_json
from server_scheduling.server_switch_basic import ServerSwitchBasic
from min_cost_flow.min_cost_flow import GetAccurateFlow

class ServerSwitchDemandPredict(ServerSwitchBasic):
    """Plan per-interval server usage from predicted (noisy) demand with Gurobi LPs.

    Shared model layout (see ``_add_share_constraints``):
        u[t, src, slot] : share of the demand of station ``src`` in interval ``t``
                          that is sent to station ``list_near[src][slot]``
        y[t, s]         : load accepted by station ``s`` in interval ``t``; bounded
                          by ``server_placement[s]`` (or a given ``y_array``)
    Every station's slots sum to the same per-interval share (a variable ``a[t]``
    or a fixed target), and the demand routed onto station ``s`` may not exceed
    ``y[t, s]``.

    NOTE(review): all variables use ``vtype='S'`` (semi-continuous in Gurobi).
    With ``lb=0`` they take any value in ``[0, ub]``, but they make the model a
    MIP; ``'C'`` may have been intended -- confirm before changing.
    """

    def __init__(self, location_array, k_near, k_relative, lanti1, long1, server_placement, time_interval, predict_error_rate_maximum,
                 lanti_list, long_list, frequ_list, adjust_rate):
        """
        param
        ----------
            location_array : array
                n x 2 array of locations; first column latitude, second column
                longitude (forwarded to ServerSwitchBasic)
            k_near : int
                the number of near base stations to find (forwarded to ServerSwitchBasic)
            k_relative : int
                how many of the most related stations (from server_union.json)
                each station may send its demand to
            lanti1, long1 : list
                latitude / longitude of every base station; their order defines
                the station indices used by all models
            server_placement : sequence
                per-station upper bound on y[t, s] in the planning models
            time_interval :
                stored on the instance; not read by the methods of this class
            predict_error_rate_maximum : float
                maximum relative prediction error used by add_error
            lanti_list, long_list, frequ_list : list of lists
                per time interval: latitudes, longitudes and loads of the
                observed locations
            adjust_rate : float
                scale factor applied to every rebuilt load
        """
        super().__init__(location_array, k_near)
        self.lanti1 = lanti1
        self.long1 = long1
        self.server_placement = server_placement
        self.time_interval = time_interval
        self.server_arrange = np.zeros((1, 1))
        self.optimal_list = []
        self.predict_error_rate_maximum = predict_error_rate_maximum
        self.lanti_list = lanti_list
        self.long_list = long_list
        self.frequ_list = frequ_list
        self.adjust_rate = adjust_rate
        self.k_relative = k_relative

    def load_distribution(self, lanti_1, long_1, lanti_2, long_2, load):
        """Project the loads of location set 2 onto the station order of set 1.

        A location of set 2 is matched when its (latitude, longitude) pair occurs
        in set 1. If that pair occurs several times in set 1, the first occurrence
        is used.

        param
        -----------
            lanti_1, long_1 : list
                latitude / longitude of every location of set 1 (the stations)
            lanti_2, long_2 : list
                latitude / longitude of every location of set 2
            load : list
                load of every location of set 2 (same length as lanti_2)

        return
        ----------
            load_rebuild : list
                one entry per location of set 1: the load of the matching set-2
                location (a later duplicate overwrites an earlier one), else 0
            num : int
                the number of set-2 locations that were matched
        """
        # Match the (lat, lon) pair as a whole. Looking up latitude and longitude
        # separately misses locations whose latitude or longitude first appears
        # at a different index, and costs O(len(set1)) per lookup.
        position = {}
        for idx, pair in enumerate(zip(lanti_1, long_1)):
            position.setdefault(pair, idx)

        load_rebuild = [0] * len(lanti_1)
        num = 0
        for lanti, longi, value in zip(lanti_2, long_2, load):
            idx = position.get((lanti, longi))
            if idx is not None:
                load_rebuild[idx] = value
                num += 1
        return load_rebuild, num

    def add_error(self, load_distribution):
        """Return ``load_distribution`` with multiplicative uniform noise.

        Each entry is multiplied by ``1 + r * e`` with ``e`` drawn from
        ``np.random.uniform(-1, 1)`` (NumPy's global random state) and
        ``r = predict_error_rate_maximum``.
        """
        error_rate = self.predict_error_rate_maximum * np.random.uniform(-1, 1, (1, len(load_distribution)))
        predict_distribution = (np.ones((1, len(load_distribution))) + error_rate) * np.array(load_distribution)
        return list(predict_distribution[0])

    def sum_distribution(self):
        """Rebuild every time interval's load onto the station order.

        return
        ----------
            final_rebuild : list of lists
                per interval, the observed load of each station times adjust_rate
            final_rebuild_error : list of lists
                the same loads after add_error (the "predicted" demand), times
                adjust_rate; a fresh random draw on every call
        """
        final_rebuild = []
        final_rebuild_error = []
        for lanti_2, long_2, load in zip(self.lanti_list, self.long_list, self.frequ_list):
            load_rebuild, _ = self.load_distribution(self.lanti1, self.long1, lanti_2, long_2, load)
            load_rebuild_error = self.add_error(load_rebuild)
            final_rebuild.append(list(np.array(load_rebuild) * self.adjust_rate))
            final_rebuild_error.append(list(np.array(load_rebuild_error) * self.adjust_rate))
        return final_rebuild, final_rebuild_error

    def find_relative_loca(self):
        """Load station relations and keep the ``k_relative`` strongest per station.

        Reads ``server_union.json``. Its i-th entry is a dict whose keys are
        station indices (as strings) and whose values are scores; higher is
        treated as more related to station i.

        return
        ----------
            list_relative : list of lists
                list_relative[i] = up to k_relative station indices, by
                descending score
            list_relative_assign : list of lists
                inverse relation: list_relative_assign[s] lists every i with
                s in list_relative[i], in ascending i
        """
        union_relative = load_json(ROOT_PATH + '/all_data/server_scheduling_result/server_union_statistic/',
                                   'server_union.json')
        list_relative = []
        list_relative_assign = [[] for _ in range(len(union_relative))]
        for i, scores in enumerate(union_relative):
            ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
            related = [int(key) for key, _ in ranked[:self.k_relative]]
            list_relative.append(related)
            # Iterate over what was actually kept, so stations with fewer than
            # k_relative partners no longer raise IndexError.
            for station in related:
                list_relative_assign[station].append(i)
        return list_relative, list_relative_assign

    @staticmethod
    def _slot_width(list_near):
        """Number of slots per station in ``u``: the longest row of ``list_near``."""
        return max((len(row) for row in list_near), default=0)

    @staticmethod
    def _solution_array(var_dict, shape):
        """Solution values (``.X``) of a multi-indexed tupledict as an ndarray of ``shape``."""
        return np.array([var_dict[index].X for index in np.ndindex(*shape)], dtype=float).reshape(shape)

    @staticmethod
    def _add_share_constraints(model, y, u, share, capacity, demand, list_near, list_assign):
        """Add the constraints shared by every planning model.

        For each interval t, in this order:
            y[t, s] <= capacity(t, s)                              for every station s
            sum over the slots of s of u[t, s, slot] == share[t]   for every station s
        then, for each interval t and station s:
            sum over src in list_assign[s] of
                u[t, src, slot of s in list_near[src]] * demand[t][src] <= y[t, s]
        """
        time_part = len(demand)
        base_num = len(list_near)
        for t in range(time_part):
            model.addConstrs(y[t, s] <= capacity(t, s) for s in range(base_num))
            # Sum each station's own slots; unused padding slots must not add free share.
            model.addConstrs(
                gr.quicksum(u[t, s, slot] for slot in range(len(list_near[s]))) == share[t]
                for s in range(base_num)
            )
        for t in range(time_part):
            model.addConstrs(
                gr.quicksum(u[t, src, list_near[src].index(s)] * demand[t][src] for src in list_assign[s])
                <= y[t, s]
                for s in range(base_num)
            )

    def _solve_switch_model(self, list_near, list_assign, demand, a_list, weight_switch, weight_await):
        """Build and solve the minimum-energy switching model for fixed shares ``a_list``.

        Minimises weight_switch * sum |y[t, s] - y[t+1, s]| + weight_await * sum y[t, s]
        (the absolute value is linearised through the ``switch`` variables).

        return
        ----------
            (model, y, energy_cost, u) : the solved model and its variable handles
        """
        time_part = len(demand)
        base_num = len(list_near)
        model = gr.Model('Lip_placement')
        y = model.addVars(time_part, base_num, vtype='S', lb=0)
        energy_cost = model.addVar(vtype='S', lb=0)
        u = model.addVars(time_part, base_num, self._slot_width(list_near), vtype='S', lb=0)
        switch = model.addVars(time_part - 1, base_num, vtype='S', lb=0)

        model.setObjective(energy_cost, gr.GRB.MINIMIZE)
        self._add_share_constraints(model, y, u, a_list, lambda t, s: self.server_placement[s],
                                    demand, list_near, list_assign)
        for t in range(time_part - 1):
            model.addConstrs(switch[t, s] >= y[t, s] - y[t + 1, s] for s in range(base_num))
            model.addConstrs(switch[t, s] >= y[t + 1, s] - y[t, s] for s in range(base_num))
        model.addConstr(
            energy_cost == weight_switch * gr.quicksum(switch[t, s] for s in range(base_num) for t in range(time_part - 1))
            + weight_await * gr.quicksum(y[t, s] for s in range(base_num) for t in range(time_part))
        )
        model.optimize()
        return model, y, energy_cost, u

    def _prepare_switch_inputs(self):
        """Load relations and predicted demand once, then refresh ``optimal_list`` on them."""
        list_near, list_assign = self.find_relative_loca()
        _, demand = self.sum_distribution()
        # Reuse the same noisy demand realisation for the reference shares;
        # previously server_usage_predict drew a second, different one.
        self.server_usage_predict(list_near, list_assign, demand)
        return list_near, list_assign, demand

    def server_usage_predict(self, list_near=None, list_assign=None, demand=None):
        """Find the largest share of demand servable in each interval.

        Maximises sum_t a[t] subject to the shared constraints with
        y[t, s] <= server_placement[s], and stores the optimal a[t] values in
        ``self.optimal_list``.

        param
        ----------
            list_near, list_assign : list of lists, optional
                relations as returned by find_relative_loca; loaded when omitted
            demand : list of lists, optional
                predicted per-interval load; when omitted, a fresh noisy
                realisation is drawn via sum_distribution

        raise
        ----------
            RuntimeError : if Gurobi does not report an optimal solution
        """
        if list_near is None or list_assign is None:
            list_near, list_assign = self.find_relative_loca()
        if demand is None:
            _, demand = self.sum_distribution()
        time_part = len(demand)
        base_num = len(list_near)

        model_placement = gr.Model('Lip_placement')
        y = model_placement.addVars(time_part, base_num, vtype='S', lb=0)
        a = model_placement.addVars(time_part, vtype='S', lb=0)
        u = model_placement.addVars(time_part, base_num, self._slot_width(list_near), vtype='S', lb=0)

        model_placement.setObjective(gr.quicksum(a[t] for t in range(time_part)), gr.GRB.MAXIMIZE)
        self._add_share_constraints(model_placement, y, u, a, lambda t, s: self.server_placement[s],
                                    demand, list_near, list_assign)
        model_placement.optimize()

        if model_placement.Status != gr.GRB.OPTIMAL:
            raise RuntimeError("server_usage_predict: no optimal solution (status %d)" % model_placement.Status)
        self.optimal_list = [a[t].X for t in range(time_part)]

    def server_switch(self, max_rate, weight_switch, weight_await):
        """Plan y[t, s] that minimises switching plus idle cost.

        The target share per interval is 1/max_rate when it stays at least
        0.2/max_rate below the predicted optimum, otherwise 80 % of that optimum.

        return
        ----------
            result_list : ndarray, shape (time_part, base_num)
                optimal y values
            final_optimal : list
                one-element list holding the optimal energy cost
            (0, 0) is returned instead when the model is not solved to optimality.
        """
        list_near, list_assign, demand = self._prepare_switch_inputs()
        time_part = len(demand)
        base_num = len(list_near)
        a = 1 / max_rate
        a_list = [a if a <= optimal - 0.2 * a else optimal * 0.8 for optimal in self.optimal_list]

        model, y, energy_cost, _ = self._solve_switch_model(list_near, list_assign, demand, a_list,
                                                            weight_switch, weight_await)
        if model.Status != gr.GRB.OPTIMAL:
            print("no solve")
            return 0, 0
        result_list = self._solution_array(y, (time_part, base_num))
        # Previously an empty slice; the objective value is the energy_cost variable.
        final_optimal = [energy_cost.X]
        return result_list, final_optimal

    def server_switch_max_flow(self, max_rate, weight_switch, weight_await):
        """Plan as in server_switch, then refine each interval with GetAccurateFlow.

        The target share per interval is 1/max_rate when it stays at least
        0.1/max_rate below the predicted optimum, otherwise the optimum minus
        0.1/max_rate.

        return
        ----------
            y_list_sum : ndarray, shape (time_part, base_num)
                per-interval result of GetAccurateFlow
            final_optimal : list
                one-element list holding the optimal energy cost
            (0, 0) is returned instead when the model is not solved to optimality.
        """
        list_near, list_assign, demand = self._prepare_switch_inputs()
        time_part = len(demand)
        base_num = len(list_near)
        a = 1 / max_rate
        a_list = [a if a <= optimal - 0.1 * a else optimal - 0.1 * a for optimal in self.optimal_list]

        model, y, energy_cost, u = self._solve_switch_model(list_near, list_assign, demand, a_list,
                                                            weight_switch, weight_await)
        if model.Status != gr.GRB.OPTIMAL:
            print("no solve")
            return 0, 0
        result_list = self._solution_array(y, (time_part, base_num))
        # Read u by handle: the old positional slice also swallowed the switch variables.
        u_array = self._solution_array(u, (time_part, base_num, self._slot_width(list_near)))
        final_optimal = [energy_cost.X]

        y_list_sum = np.zeros((time_part, base_num))
        for t in range(time_part):
            # Expand slot-indexed shares into a dense base_num x base_num matrix:
            # u_list_real[src][dst] = u[t, src, slot] where list_near[src][slot] == dst.
            u_list_real = []
            for src in range(base_num):
                row = [0] * base_num
                for slot, dst in enumerate(list_near[src]):
                    row[dst] = u_array[t, src, slot]
                u_list_real.append(row)
            # Distinct index names: the inner loop used to clobber the interval index.
            y_list, _ = GetAccurateFlow(u_list_real, list(result_list[t, :]), demand[t], a_list[t])
            y_list_sum[t, :] = np.array(y_list)
        return y_list_sum, final_optimal

    def server_usage(self, y_array):
        """Evaluate a given plan ``y_array`` against the observed (noise-free) demand.

        Maximises sum_t a[t] subject to the shared constraints with
        y[t, s] <= y_array[t, s].

        param
        ----------
            y_array : ndarray indexable as y_array[t, s]
                capacity plan, e.g. the result of server_switch

        return
        ----------
            result_list : ndarray (time_part, base_num) -- optimal y values
            final_optimal : list -- optimal a[t] per interval
            u_array : ndarray (time_part, base_num, slots) -- optimal u values
            list_assign, final_rebuild, base_num, list_near : the model inputs

        raise
        ----------
            RuntimeError : if Gurobi does not report an optimal solution
        """
        list_near, list_assign = self.find_relative_loca()
        final_rebuild, _ = self.sum_distribution()
        time_part = len(final_rebuild)
        base_num = len(list_near)
        width = self._slot_width(list_near)

        model_placement = gr.Model('Lip_placement')
        y = model_placement.addVars(time_part, base_num, vtype='S', lb=0)
        a = model_placement.addVars(time_part, vtype='S', lb=0)
        u = model_placement.addVars(time_part, base_num, width, vtype='S', lb=0)

        model_placement.setObjective(gr.quicksum(a[t] for t in range(time_part)), gr.GRB.MAXIMIZE)
        self._add_share_constraints(model_placement, y, u, a, lambda t, s: y_array[t, s],
                                    final_rebuild, list_near, list_assign)
        model_placement.optimize()

        if model_placement.Status != gr.GRB.OPTIMAL:
            raise RuntimeError("server_usage: no optimal solution (status %d)" % model_placement.Status)
        result_list = self._solution_array(y, (time_part, base_num))
        final_optimal = [a[t].X for t in range(time_part)]
        u_array = self._solution_array(u, (time_part, base_num, width))
        return result_list, final_optimal, u_array, list_assign, final_rebuild, base_num, list_near

    def service_failure_statistic(self, y_array, max_rate):
        """Count, per interval, the stations flagged as failing under plan ``y_array``.

        Intervals whose optimal share reaches 1/max_rate count 0 failures.
        Otherwise a station s is counted when the demand routed onto it
        exceeds y[t, s] * a[t], using the solution of server_usage.

        return
        ----------
            (failures per interval as ndarray, the same divided by base_num)
        """
        result_list, final_optimal, u_array, list_assign, final_rebuild, base_num, list_near = self.server_usage(y_array)
        a = 1 / max_rate
        service_failed_num = []
        for t, optimal in enumerate(final_optimal):
            failed_num = 0
            if optimal >= a:
                service_failed_num.append(failed_num)
                continue
            for s in range(base_num):
                service_temp = sum(u_array[t, src, list_near[src].index(s)] * final_rebuild[t][src]
                                   for src in list_assign[s])
                if service_temp > result_list[t, s] * optimal:
                    failed_num += 1
            service_failed_num.append(failed_num)
        failed = np.array(service_failed_num)
        return failed, failed / base_num